import logging
import pprint

from constants import causes, csrs, csrs32
from shared_utils import InstrDict, instr_dict_2_extensions

# Log line layout for this module: "<LEVEL>:: <message>".
_LOG_FORMAT = "%(levelname)s:: %(message)s"

# Module-level pretty-printer with a 2-space indent (not used in the visible
# part of this file; kept for callers that import it).
pp = pprint.PrettyPrinter(indent=2)

# Configure the root logger at import time (INFO and above).
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)


def _scala_extension_name(extension: str) -> str:
    """Map an extension name to the Scala identifier stem used for its Map.

    A leading ``rv64_`` / ``rv32_`` prefix is stripped and replaced by a
    ``64`` / ``32`` suffix; a leading ``rv_`` is simply stripped; anything
    else is just upper-cased (e.g. ``rv64_i`` -> ``I64``, ``rv_zicsr`` ->
    ``ZICSR``).
    """
    # Order matters only for readability: the prefixes are mutually exclusive.
    for prefix, suffix in (("rv64_", "64"), ("rv32_", "32"), ("rv_", "")):
        if extension.startswith(prefix):
            return extension[len(prefix):].upper() + suffix
    return extension.upper()


def _spinal_instruction_defs(instr_dict: InstrDict) -> str:
    """Return one SpinalHDL ``def NAME = M"b..."`` line per instruction.

    The encoding string is emitted as-is: SpinalHDL masked literals accept
    ``-`` directly, so no character translation is needed.
    """
    lines = []
    for name, instr in instr_dict.items():
        scala_name = name.upper().replace(".", "_")
        lines.append(f'  def {scala_name:<18s} = M"b{instr["encoding"]}"\n')
    return "".join(lines)


def _chisel_instruction_maps(instr_dict: InstrDict) -> str:
    """Return one Scala ``val <EXT>Type = Map(...)`` of Chisel ``BitPat``s per extension.

    Extensions are emitted in the order returned by
    ``instr_dict_2_extensions``; each instruction is listed under the first
    entry of its ``"extension"`` sequence, preserving ``instr_dict`` order.
    ``-`` characters in the encoding are rewritten to ``?`` for ``BitPat``.
    """
    # Bucket instructions by primary extension once instead of rescanning
    # the whole dict for every extension.
    by_extension = {}
    for name, instr in instr_dict.items():
        by_extension.setdefault(instr["extension"][0], []).append((name, instr))

    lines = []
    for extension in instr_dict_2_extensions(instr_dict):
        lines.append(f"  val {_scala_extension_name(extension)}Type = Map(\n")
        for name, instr in by_extension.get(extension, ()):
            quoted = '"' + name.upper().replace(".", "_") + '"'
            bitpat = instr["encoding"].replace("-", "?")
            lines.append(f'   {quoted:<18s} -> BitPat("b{bitpat}"),\n')
        lines.append("  )\n")
    return "".join(lines)


def _scala_collect(val_name: str, buffer_init: str, members) -> str:
    """Return a Scala ``val`` block that appends *members* to a buffer and
    converts it to an array (no trailing newline)."""
    lines = [f"  val {val_name} = {{\n", f"    val res = {buffer_init}\n"]
    lines.extend(f"    res += {member}\n" for member in members)
    lines.append("    res.toArray\n  }")
    return "".join(lines)


def _causes_object_body(cause_table) -> str:
    """Return the body of ``object Causes`` from ``(number, name)`` pairs.

    Names are lower-cased with spaces turned into underscores; ``all``
    collects every cause value.
    """
    entries = [(num, name.lower().replace(" ", "_")) for num, name in cause_table]
    defs = "".join(f"  val {name} = {hex(num)}\n" for num, name in entries)
    return defs + _scala_collect(
        "all",
        "collection.mutable.ArrayBuffer[Int]()",
        (name for _, name in entries),
    )


def _csrs_object_body(csr_table, csr32_table) -> str:
    """Return the body of ``object CSRs`` from ``(number, name)`` pairs.

    Every CSR (both tables) gets a ``val``; ``all`` collects the first
    table and ``all32`` extends ``all`` with the second table.
    """
    defs = "".join(
        f"  val {name} = {hex(num)}\n"
        for table in (csr_table, csr32_table)
        for num, name in table
    )
    return (
        defs
        + _scala_collect(
            "all",
            "collection.mutable.ArrayBuffer[Int]()",
            (name for _, name in csr_table),
        )
        + "\n"
        + _scala_collect(
            "all32",
            "collection.mutable.ArrayBuffer(all:_*)",
            (name for _, name in csr32_table),
        )
    )


def make_chisel(instr_dict: InstrDict, spinal_hdl: bool = False):
    """Generate Scala constants for instructions, trap causes and CSRs.

    Writes ``inst.chisel`` (default) or ``inst.spinalhdl`` in the current
    working directory, containing ``object Instructions``, ``object Causes``
    and ``object CSRs``.

    Args:
        instr_dict: Mapping of instruction name to its record; each record's
            ``"encoding"`` string and the first element of its
            ``"extension"`` sequence are read.
        spinal_hdl: When true, emit one SpinalHDL masked-literal ``def`` per
            instruction; otherwise emit one Chisel ``BitPat`` ``Map`` per
            extension.

    Raises:
        OSError: If the output file cannot be opened or written.
    """
    if spinal_hdl:
        instructions = _spinal_instruction_defs(instr_dict)
    else:
        instructions = _chisel_instruction_maps(instr_dict)
    cause_body = _causes_object_body(causes)
    csr_body = _csrs_object_body(csrs, csrs32)

    out_name = "inst.spinalhdl" if spinal_hdl else "inst.chisel"
    with open(out_name, "w", encoding="utf-8") as chisel_file:
        chisel_file.write(
            f"""
/* Automatically generated by parse_opcodes */
object Instructions {{
{instructions}
}}
object Causes {{
{cause_body}
}}
object CSRs {{
{csr_body}
}}
"""
        )
